{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Learn OpenAI Whisper - Chapter 8\n",
        "\n",
        "## Diarizing speech with WhisperX and NVIDIA's NeMo\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-TSAhmGl5QplD9-NzWLOONkH43SEMwdF)\n",
        "\n",
        "This notebook offers a step-by-step guide to implementing speaker diarization using [WhisperX](https://replicate.com/dinozoiddev/whisperx) and NVIDIA's NeMo. WhisperX combines the power of `whisper-large-v3` with diarization upgrades, while NeMo provides a robust framework for speaker diarization. The goal is to demonstrate how these tools can be integrated to accurately identify \"who spoke when\" in an audio file, enhancing the capabilities of automated speech recognition (ASR) systems.\n",
        "\n",
        "### Notebook Structure\n",
        "\n",
        "The notebook is organized into key sections, each addressing a specific aspect of the diarization process:\n",
        "\n",
        "1. **Setting up the environment**: Install necessary dependencies, including Whisper, NeMo, and supporting libraries.\n",
        "\n",
        "2. **Streamlining the diarization workflow with helper functions**: Utility functions for audio manipulation, data processing, and speaker embedding generation.\n",
        "\n",
        "3. **Separating music from speech using Demucs**: Separate music from speech using Demucs and convert audio to mono format for NeMo compatibility.\n",
        "\n",
        "4. **Transcribing audio using WhisperX**: Transcribe audio using WhisperX, leveraging its high accuracy in speech recognition.\n",
        "\n",
        "5. **Aligning Transcriptions with Wav2Vec2**: Align the transcription with the original audio using Wav2Vec2, a neural network designed for speech recognition and alignment tasks.\n",
        "\n",
        "6. **Using NeMo's MSDD model for speaker diarization**: Employ NVIDIA's NeMo MSDD model to separate transcribed audio into segments based on speaker identity.\n",
        "\n",
        "7. **Mapping speakers to sentences according to timestamps**: Map identified speakers to corresponding sentences in the transcription based on timestamps.\n",
        "\n",
        "8. **Enhancing speaker attribution with punctuation-based realignment**: Refine speaker labels using punctuation, addressing scenarios where sentences are split between speakers or background comments occur.\n",
        "\n",
        "9. **Finalizing the diarization process**: : Cleanup, result export, and speaker name mapping.\n",
        "\n",
        "This notebook provides a comprehensive resource for researchers, developers, and practitioners interested in exploring advanced diarization techniques within ASR systems. By integrating Whisper's transcription capabilities with NeMo's diarization framework, it offers a powerful solution for analyzing speech in audio recordings."
      ],
      "metadata": {
        "id": "AuMkmL00BdYi"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eCmjcOc9yEtQ"
      },
      "source": [
        "# Setting up the environment\n",
        "\n",
        "This section focuses on installing the necessary Python libraries and tools for the diarization process:\n",
        "\n",
        "- **whisperX**: An enhanced extension of OpenAI's Whisper model, providing additional functionality.\n",
        "- **nemo_toolkit[asr]**: NVIDIA's NeMo toolkit for automatic speech recognition, serving as the backbone for speaker diarization.\n",
        "- **demucs**: A library for music source separation, enabling the isolation of speech from background music during audio preprocessing.\n",
        "- **dora-search, lameenc, and openunmix**: Tools and libraries for audio processing, improving the quality and compatibility of audio data for diarization tasks.\n",
        "- **deepmultilingualpunctuation**: A library for adding punctuation to transcriptions, enhancing readability and structure.\n",
        "- **wget and pydub**: Utilities for downloading and manipulating audio files, simplifying audio data handling within the Python environment.\n",
        "\n",
        "These libraries work together to form a comprehensive toolset for processing audio files, transcribing speech, and performing speaker diarization. Each tool contributes to a specific aspect of the process, from audio data preparation to accurate transcription and speaker identification."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tn1c-CoDv2kw"
      },
      "outputs": [],
      "source": [
        "!pip install -q git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n",
        "!pip install -q --no-build-isolation nemo_toolkit[asr]==1.22.0\n",
        "!pip install -q --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n",
        "!pip install -q dora-search \"lameenc>=1.2\" openunmix\n",
        "!pip install -q deepmultilingualpunctuation\n",
        "!pip install -q wget pydub"
      ]
    },
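    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the installs above run with `-q`, a failed or skipped package is easy to miss. The optional cell below is a minimal sketch that uses only the standard library to report which versions were installed; the distribution names in `packages` are assumptions inferred from the `pip install` lines above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from importlib.metadata import PackageNotFoundError, version\n",
        "\n",
        "# Distribution names assumed from the pip commands above\n",
        "packages = [\n",
        "    \"whisperx\",\n",
        "    \"nemo_toolkit\",\n",
        "    \"demucs\",\n",
        "    \"deepmultilingualpunctuation\",\n",
        "    \"pydub\",\n",
        "    \"wget\",\n",
        "]\n",
        "\n",
        "for name in packages:\n",
        "    try:\n",
        "        print(f\"{name}: {version(name)}\")\n",
        "    except PackageNotFoundError:\n",
        "        # The package metadata was not found in this environment\n",
        "        print(f\"{name}: NOT INSTALLED\")"
      ]
    },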
    {
      "cell_type": "markdown",
      "source": [
        "# RESTART\n",
        "![Restart_the_runtime_600x102.png](https://github.com/PacktPublishing/Learn-OpenAI-Whisper/raw/main/Chapter08/Restart_the_runtime_600x102.png)"
      ],
      "metadata": {
        "id": "Nk-2j7nT-BUa"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YzhncHP0ytbQ"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import wget\n",
        "from omegaconf import OmegaConf\n",
        "import json\n",
        "import shutil\n",
        "from faster_whisper import WhisperModel\n",
        "import whisperx\n",
        "import torch\n",
        "from pydub import AudioSegment\n",
        "from nemo.collections.asr.models.msdd_models import NeuralDiarizer\n",
        "from deepmultilingualpunctuation import PunctuationModel\n",
        "import re\n",
        "import logging\n",
        "import nltk\n",
        "from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH\n",
        "from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ONlFrSnD0FOp"
      },
      "outputs": [],
      "source": [
        "# Download sample multi-speaker audio file\n",
        "# Source: www.youtube.com/@Channel4News\n",
        "!wget -nv https://github.com/PacktPublishing/Learn-OpenAI-Whisper/raw/main/Chapter08/20150415-Fracking_the_debate.mp4\n",
        "\n",
        "# Name of the audio file\n",
        "audio_path = \"/content/20150415-Fracking_the_debate.mp4\"\n",
        "\n",
        "# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram\n",
        "enable_stemming = True\n",
        "\n",
        "# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large')\n",
        "whisper_model_name = \"large-v2\"\n",
        "\n",
        "# replaces numerical digits with their pronounciation, increases diarization accuracy\n",
        "suppress_numerals = True\n",
        "\n",
        "batch_size = 8\n",
        "\n",
        "language = None  # autodetect language\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jbsUt3SwyhjD"
      },
      "source": [
        "# 2. Streamlining the diarization workflow with helper functions\n",
        "\n",
        "This section introduces a set of helper functions designed to streamline the process of diarizing speech using Whisper and NeMo. These functions play a crucial role in managing audio data, aligning transcriptions with speaker identities, and enhancing the overall workflow. Here's a brief overview of the key functions:\n",
        "\n",
        "- **`create_config()`**: Sets up essential parameters for the diarization process by initializing and returning a configuration object.\n",
        "\n",
        "- **`get_word_ts_anchor()`**: Determines the anchor timestamp for words, enabling accurate alignment between spoken words and their timestamps in the audio.\n",
        "\n",
        "- **`get_words_speaker_mapping()`**: Maps each word in the transcription to the corresponding speaker based on the diarization results, ensuring accurate speaker attribution.\n",
        "\n",
        "- **`get_first_word_idx_of_sentence()`** and **`get_last_word_idx_of_sentence()`**: Identifies the indices of the first and last words in a sentence, facilitating sentence-level processing for speaker attribution and alignment.\n",
        "\n",
        "- **`get_realigned_ws_mapping_with_punctuation()`**: Enhances the accuracy of speaker attribution by adjusting the word-to-speaker mapping considering punctuation, particularly in complex conversational scenarios.\n",
        "\n",
        "- **`get_sentences_speaker_mapping()`**: Generates a mapping of entire sentences to speakers, providing a higher-level view of speaker contributions throughout the audio.\n",
        "\n",
        "- **`get_speaker_aware_transcript()`**: Produces a transcript that integrates textual content with speaker information, creating a cohesive and speaker-aware format.\n",
        "\n",
        "- **`format_timestamp()`** and **`write_srt()`**: Converts timestamps into a human-readable format and outputs the diarization results in SubRip Text (SRT) format for subtitles or detailed analysis.\n",
        "\n",
        "- **`find_numeral_symbol_tokens()`**, **`_get_next_start_timestamp()`**, and **`filter_missing_timestamps()`**: Assist in processing numerical data, ensuring continuity in timestamp sequences, and maintaining the integrity of temporal information.\n",
        "\n",
        "- **`cleanup()`**, **`process_language_arg()`**, **`transcribe()`**, and **`transcribe_batched()`**: Handle temporary file cleanup, language argument processing, audio transcription using Whisper, and batch processing for efficient transcription.\n",
        "\n",
        "These functions collectively form the backbone of the diarization workflow, enabling seamless integration of Whisper's transcription capabilities with NeMo's advanced diarization features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Se6Hc7CZygxu"
      },
      "outputs": [],
      "source": [
        "punct_model_langs = [\n",
        "    \"en\",\n",
        "    \"fr\",\n",
        "    \"de\",\n",
        "    \"es\",\n",
        "    \"it\",\n",
        "    \"nl\",\n",
        "    \"pt\",\n",
        "    \"bg\",\n",
        "    \"pl\",\n",
        "    \"cs\",\n",
        "    \"sk\",\n",
        "    \"sl\",\n",
        "]\n",
        "wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(\n",
        "    DEFAULT_ALIGN_MODELS_HF.keys()\n",
        ")\n",
        "\n",
        "whisper_langs = sorted(LANGUAGES.keys()) + sorted(\n",
        "    [k.title() for k in TO_LANGUAGE_CODE.keys()]\n",
        ")\n",
        "\n",
        "\n",
        "def create_config(output_dir):\n",
        "    DOMAIN_TYPE = \"telephonic\"  # Can be meeting, telephonic, or general based on domain type of the audio file\n",
        "    CONFIG_FILE_NAME = f\"diar_infer_{DOMAIN_TYPE}.yaml\"\n",
        "    CONFIG_URL = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}\"\n",
        "    MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)\n",
        "    if not os.path.exists(MODEL_CONFIG):\n",
        "        MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)\n",
        "\n",
        "    config = OmegaConf.load(MODEL_CONFIG)\n",
        "\n",
        "    data_dir = os.path.join(output_dir, \"data\")\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "    meta = {\n",
        "        \"audio_filepath\": os.path.join(output_dir, \"mono_file.wav\"),\n",
        "        \"offset\": 0,\n",
        "        \"duration\": None,\n",
        "        \"label\": \"infer\",\n",
        "        \"text\": \"-\",\n",
        "        \"rttm_filepath\": None,\n",
        "        \"uem_filepath\": None,\n",
        "    }\n",
        "    with open(os.path.join(data_dir, \"input_manifest.json\"), \"w\") as fp:\n",
        "        json.dump(meta, fp)\n",
        "        fp.write(\"\\n\")\n",
        "\n",
        "    pretrained_vad = \"vad_multilingual_marblenet\"\n",
        "    pretrained_speaker_model = \"titanet_large\"\n",
        "    config.num_workers = 0  # Workaround for multiprocessing hanging with ipython issue\n",
        "    config.diarizer.manifest_filepath = os.path.join(data_dir, \"input_manifest.json\")\n",
        "    config.diarizer.out_dir = (\n",
        "        output_dir  # Directory to store intermediate files and prediction outputs\n",
        "    )\n",
        "\n",
        "    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
        "    config.diarizer.oracle_vad = (\n",
        "        False  # compute VAD provided with model_path to vad config\n",
        "    )\n",
        "    config.diarizer.clustering.parameters.oracle_num_speakers = False\n",
        "\n",
        "    # Here, we use our in-house pretrained NeMo VAD model\n",
        "    config.diarizer.vad.model_path = pretrained_vad\n",
        "    config.diarizer.vad.parameters.onset = 0.8\n",
        "    config.diarizer.vad.parameters.offset = 0.6\n",
        "    config.diarizer.vad.parameters.pad_offset = -0.05\n",
        "    config.diarizer.msdd_model.model_path = (\n",
        "        \"diar_msdd_telephonic\"  # Telephonic speaker diarization model\n",
        "    )\n",
        "\n",
        "    return config\n",
        "\n",
        "\n",
        "def get_word_ts_anchor(s, e, option=\"start\"):\n",
        "    if option == \"end\":\n",
        "        return e\n",
        "    elif option == \"mid\":\n",
        "        return (s + e) / 2\n",
        "    return s\n",
        "\n",
        "\n",
        "def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option=\"start\"):\n",
        "    s, e, sp = spk_ts[0]\n",
        "    wrd_pos, turn_idx = 0, 0\n",
        "    wrd_spk_mapping = []\n",
        "    for wrd_dict in wrd_ts:\n",
        "        ws, we, wrd = (\n",
        "            int(wrd_dict[\"start\"] * 1000),\n",
        "            int(wrd_dict[\"end\"] * 1000),\n",
        "            wrd_dict[\"word\"],\n",
        "        )\n",
        "        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)\n",
        "        while wrd_pos > float(e):\n",
        "            turn_idx += 1\n",
        "            turn_idx = min(turn_idx, len(spk_ts) - 1)\n",
        "            s, e, sp = spk_ts[turn_idx]\n",
        "            if turn_idx == len(spk_ts) - 1:\n",
        "                e = get_word_ts_anchor(ws, we, option=\"end\")\n",
        "        wrd_spk_mapping.append(\n",
        "            {\"word\": wrd, \"start_time\": ws, \"end_time\": we, \"speaker\": sp}\n",
        "        )\n",
        "    return wrd_spk_mapping\n",
        "\n",
        "\n",
        "sentence_ending_punctuations = \".?!\"\n",
        "\n",
        "\n",
        "def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):\n",
        "    is_word_sentence_end = (\n",
        "        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
        "    )\n",
        "    left_idx = word_idx\n",
        "    while (\n",
        "        left_idx > 0\n",
        "        and word_idx - left_idx < max_words\n",
        "        and speaker_list[left_idx - 1] == speaker_list[left_idx]\n",
        "        and not is_word_sentence_end(left_idx - 1)\n",
        "    ):\n",
        "        left_idx -= 1\n",
        "\n",
        "    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1\n",
        "\n",
        "\n",
        "def get_last_word_idx_of_sentence(word_idx, word_list, max_words):\n",
        "    is_word_sentence_end = (\n",
        "        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
        "    )\n",
        "    right_idx = word_idx\n",
        "    while (\n",
        "        right_idx < len(word_list)\n",
        "        and right_idx - word_idx < max_words\n",
        "        and not is_word_sentence_end(right_idx)\n",
        "    ):\n",
        "        right_idx += 1\n",
        "\n",
        "    return (\n",
        "        right_idx\n",
        "        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)\n",
        "        else -1\n",
        "    )\n",
        "\n",
        "\n",
        "def get_realigned_ws_mapping_with_punctuation(\n",
        "    word_speaker_mapping, max_words_in_sentence=50\n",
        "):\n",
        "    is_word_sentence_end = (\n",
        "        lambda x: x >= 0\n",
        "        and word_speaker_mapping[x][\"word\"][-1] in sentence_ending_punctuations\n",
        "    )\n",
        "    wsp_len = len(word_speaker_mapping)\n",
        "\n",
        "    words_list, speaker_list = [], []\n",
        "    for k, line_dict in enumerate(word_speaker_mapping):\n",
        "        word, speaker = line_dict[\"word\"], line_dict[\"speaker\"]\n",
        "        words_list.append(word)\n",
        "        speaker_list.append(speaker)\n",
        "\n",
        "    k = 0\n",
        "    while k < len(word_speaker_mapping):\n",
        "        line_dict = word_speaker_mapping[k]\n",
        "        if (\n",
        "            k < wsp_len - 1\n",
        "            and speaker_list[k] != speaker_list[k + 1]\n",
        "            and not is_word_sentence_end(k)\n",
        "        ):\n",
        "            left_idx = get_first_word_idx_of_sentence(\n",
        "                k, words_list, speaker_list, max_words_in_sentence\n",
        "            )\n",
        "            right_idx = (\n",
        "                get_last_word_idx_of_sentence(\n",
        "                    k, words_list, max_words_in_sentence - k + left_idx - 1\n",
        "                )\n",
        "                if left_idx > -1\n",
        "                else -1\n",
        "            )\n",
        "            if min(left_idx, right_idx) == -1:\n",
        "                k += 1\n",
        "                continue\n",
        "\n",
        "            spk_labels = speaker_list[left_idx : right_idx + 1]\n",
        "            mod_speaker = max(set(spk_labels), key=spk_labels.count)\n",
        "            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:\n",
        "                k += 1\n",
        "                continue\n",
        "\n",
        "            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (\n",
        "                right_idx - left_idx + 1\n",
        "            )\n",
        "            k = right_idx\n",
        "\n",
        "        k += 1\n",
        "\n",
        "    k, realigned_list = 0, []\n",
        "    while k < len(word_speaker_mapping):\n",
        "        line_dict = word_speaker_mapping[k].copy()\n",
        "        line_dict[\"speaker\"] = speaker_list[k]\n",
        "        realigned_list.append(line_dict)\n",
        "        k += 1\n",
        "\n",
        "    return realigned_list\n",
        "\n",
        "\n",
        "def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):\n",
        "    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak\n",
        "    s, e, spk = spk_ts[0]\n",
        "    prev_spk = spk\n",
        "\n",
        "    snts = []\n",
        "    snt = {\"speaker\": f\"Speaker {spk}\", \"start_time\": s, \"end_time\": e, \"text\": \"\"}\n",
        "\n",
        "    for wrd_dict in word_speaker_mapping:\n",
        "        wrd, spk = wrd_dict[\"word\"], wrd_dict[\"speaker\"]\n",
        "        s, e = wrd_dict[\"start_time\"], wrd_dict[\"end_time\"]\n",
        "        if spk != prev_spk or sentence_checker(snt[\"text\"] + \" \" + wrd):\n",
        "            snts.append(snt)\n",
        "            snt = {\n",
        "                \"speaker\": f\"Speaker {spk}\",\n",
        "                \"start_time\": s,\n",
        "                \"end_time\": e,\n",
        "                \"text\": \"\",\n",
        "            }\n",
        "        else:\n",
        "            snt[\"end_time\"] = e\n",
        "        snt[\"text\"] += wrd + \" \"\n",
        "        prev_spk = spk\n",
        "\n",
        "    snts.append(snt)\n",
        "    return snts\n",
        "\n",
        "\n",
        "def get_speaker_aware_transcript(sentences_speaker_mapping, f):\n",
        "    previous_speaker = sentences_speaker_mapping[0][\"speaker\"]\n",
        "    f.write(f\"{previous_speaker}: \")\n",
        "\n",
        "    for sentence_dict in sentences_speaker_mapping:\n",
        "        speaker = sentence_dict[\"speaker\"]\n",
        "        sentence = sentence_dict[\"text\"]\n",
        "\n",
        "        # If this speaker doesn't match the previous one, start a new paragraph\n",
        "        if speaker != previous_speaker:\n",
        "            f.write(f\"\\n\\n{speaker}: \")\n",
        "            previous_speaker = speaker\n",
        "\n",
        "        # No matter what, write the current sentence\n",
        "        f.write(sentence + \" \")\n",
        "\n",
        "\n",
        "def format_timestamp(\n",
        "    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = \".\"\n",
        "):\n",
        "    assert milliseconds >= 0, \"non-negative timestamp expected\"\n",
        "\n",
        "    hours = milliseconds // 3_600_000\n",
        "    milliseconds -= hours * 3_600_000\n",
        "\n",
        "    minutes = milliseconds // 60_000\n",
        "    milliseconds -= minutes * 60_000\n",
        "\n",
        "    seconds = milliseconds // 1_000\n",
        "    milliseconds -= seconds * 1_000\n",
        "\n",
        "    hours_marker = f\"{hours:02d}:\" if always_include_hours or hours > 0 else \"\"\n",
        "    return (\n",
        "        f\"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}\"\n",
        "    )\n",
        "\n",
        "\n",
        "def write_srt(transcript, file):\n",
        "    \"\"\"\n",
        "    Write a transcript to a file in SRT format.\n",
        "\n",
        "    \"\"\"\n",
        "    for i, segment in enumerate(transcript, start=1):\n",
        "        # write srt lines\n",
        "        print(\n",
        "            f\"{i}\\n\"\n",
        "            f\"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> \"\n",
        "            f\"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\\n\"\n",
        "            f\"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\\n\",\n",
        "            file=file,\n",
        "            flush=True,\n",
        "        )\n",
        "\n",
        "\n",
        "def find_numeral_symbol_tokens(tokenizer):\n",
        "    numeral_symbol_tokens = [\n",
        "        -1,\n",
        "    ]\n",
        "    for token, token_id in tokenizer.get_vocab().items():\n",
        "        has_numeral_symbol = any(c in \"0123456789%$£\" for c in token)\n",
        "        if has_numeral_symbol:\n",
        "            numeral_symbol_tokens.append(token_id)\n",
        "    return numeral_symbol_tokens\n",
        "\n",
        "\n",
        "def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):\n",
        "    # if current word is the last word\n",
        "    if current_word_index == len(word_timestamps) - 1:\n",
        "        return word_timestamps[current_word_index][\"start\"]\n",
        "\n",
        "    next_word_index = current_word_index + 1\n",
        "    while current_word_index < len(word_timestamps) - 1:\n",
        "        if word_timestamps[next_word_index].get(\"start\") is None:\n",
        "            # if next word doesn't have a start timestamp\n",
        "            # merge it with the current word and delete it\n",
        "            word_timestamps[current_word_index][\"word\"] += (\n",
        "                \" \" + word_timestamps[next_word_index][\"word\"]\n",
        "            )\n",
        "\n",
        "            word_timestamps[next_word_index][\"word\"] = None\n",
        "            next_word_index += 1\n",
        "            if next_word_index == len(word_timestamps):\n",
        "                return final_timestamp\n",
        "\n",
        "        else:\n",
        "            return word_timestamps[next_word_index][\"start\"]\n",
        "\n",
        "\n",
        "def filter_missing_timestamps(\n",
        "    word_timestamps, initial_timestamp=0, final_timestamp=None\n",
        "):\n",
        "    # handle the first and last word\n",
        "    if word_timestamps[0].get(\"start\") is None:\n",
        "        word_timestamps[0][\"start\"] = (\n",
        "            initial_timestamp if initial_timestamp is not None else 0\n",
        "        )\n",
        "        word_timestamps[0][\"end\"] = _get_next_start_timestamp(\n",
        "            word_timestamps, 0, final_timestamp\n",
        "        )\n",
        "\n",
        "    result = [\n",
        "        word_timestamps[0],\n",
        "    ]\n",
        "\n",
        "    for i, ws in enumerate(word_timestamps[1:], start=1):\n",
        "        # if ws doesn't have a start and end\n",
        "        # use the previous end as start and next start as end\n",
        "        if ws.get(\"start\") is None and ws.get(\"word\") is not None:\n",
        "            ws[\"start\"] = word_timestamps[i - 1][\"end\"]\n",
        "            ws[\"end\"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)\n",
        "\n",
        "        if ws[\"word\"] is not None:\n",
        "            result.append(ws)\n",
        "    return result\n",
        "\n",
        "\n",
        "def cleanup(path: str):\n",
        "    \"\"\"path could either be relative or absolute.\"\"\"\n",
        "    # check if file or directory exists\n",
        "    if os.path.isfile(path) or os.path.islink(path):\n",
        "        # remove file\n",
        "        os.remove(path)\n",
        "    elif os.path.isdir(path):\n",
        "        # remove directory and all its content\n",
        "        shutil.rmtree(path)\n",
        "    else:\n",
        "        raise ValueError(\"Path {} is not a file or dir.\".format(path))\n",
        "\n",
        "\n",
        "def process_language_arg(language: str, model_name: str):\n",
        "    \"\"\"\n",
        "    Process the language argument to make sure it's valid and convert language names to language codes.\n",
        "    \"\"\"\n",
        "    if language is not None:\n",
        "        language = language.lower()\n",
        "    if language not in LANGUAGES:\n",
        "        if language in TO_LANGUAGE_CODE:\n",
        "            language = TO_LANGUAGE_CODE[language]\n",
        "        else:\n",
        "            raise ValueError(f\"Unsupported language: {language}\")\n",
        "\n",
        "    if model_name.endswith(\".en\") and language != \"en\":\n",
        "        if language is not None:\n",
        "            logging.warning(\n",
        "                f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
        "            )\n",
        "        language = \"en\"\n",
        "    return language\n",
        "\n",
        "\n",
        "def transcribe(\n",
        "    audio_file: str,\n",
        "    language: str,\n",
        "    model_name: str,\n",
        "    compute_dtype: str,\n",
        "    suppress_numerals: bool,\n",
        "    device: str,\n",
        "):\n",
        "    from faster_whisper import WhisperModel\n",
        "    from helpers import find_numeral_symbol_tokens, wav2vec2_langs\n",
        "\n",
        "    # Faster Whisper non-batched\n",
        "    # Run on GPU with FP16\n",
        "    whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)\n",
        "\n",
        "    # or run on GPU with INT8\n",
        "    # model = WhisperModel(model_size, device=\"cuda\", compute_type=\"int8_float16\")\n",
        "    # or run on CPU with INT8\n",
        "    # model = WhisperModel(model_size, device=\"cpu\", compute_type=\"int8\")\n",
        "\n",
        "    if suppress_numerals:\n",
        "        numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)\n",
        "    else:\n",
        "        numeral_symbol_tokens = None\n",
        "\n",
        "    if language is not None and language in wav2vec2_langs:\n",
        "        word_timestamps = False\n",
        "    else:\n",
        "        word_timestamps = True\n",
        "\n",
        "    segments, info = whisper_model.transcribe(\n",
        "        audio_file,\n",
        "        language=language,\n",
        "        beam_size=5,\n",
        "        word_timestamps=word_timestamps,  # TODO: disable this if the language is supported by wav2vec2\n",
        "        suppress_tokens=numeral_symbol_tokens,\n",
        "        vad_filter=True,\n",
        "    )\n",
        "    whisper_results = []\n",
        "    for segment in segments:\n",
        "        whisper_results.append(segment._asdict())\n",
        "    # clear gpu vram\n",
        "    del whisper_model\n",
        "    torch.cuda.empty_cache()\n",
        "    return whisper_results, language\n",
        "\n",
        "\n",
        "def transcribe_batched(\n",
        "    audio_file: str,\n",
        "    language: str,\n",
        "    batch_size: int,\n",
        "    model_name: str,\n",
        "    compute_dtype: str,\n",
        "    suppress_numerals: bool,\n",
        "    device: str,\n",
        "):\n",
        "    import whisperx\n",
        "\n",
        "    # Faster Whisper batched\n",
        "    whisper_model = whisperx.load_model(\n",
        "        model_name,\n",
        "        device,\n",
        "        compute_type=compute_dtype,\n",
        "        asr_options={\"suppress_numerals\": suppress_numerals},\n",
        "    )\n",
        "    audio = whisperx.load_audio(audio_file)\n",
        "    result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
        "    del whisper_model\n",
        "    torch.cuda.empty_cache()\n",
        "    return result[\"segments\"], result[\"language\"]"
      ]
    },
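    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The word-to-speaker helpers are easier to follow on a tiny, hand-made example. The cell below is a minimal sketch and not part of the pipeline: the word timings and speaker turns in `toy_words` and `toy_turns` are invented for illustration, and it assumes the helper cell above has already been run. It maps each word to a speaker by its start time, then lets `get_realigned_ws_mapping_with_punctuation()` reassign words so that each sentence belongs to one speaker.\n",
        "\n",
        "In this toy input, the word \"It\" starts just before the turn boundary, so the raw mapping should give it to the first speaker, while the realigned mapping should move it to the speaker of the rest of its sentence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy data, invented for illustration (not taken from the sample recording).\n",
        "# Word timestamps are in seconds, as the mapping helper expects.\n",
        "toy_words = [\n",
        "    {\"word\": \"Is\", \"start\": 0.1, \"end\": 0.4},\n",
        "    {\"word\": \"fracking\", \"start\": 0.5, \"end\": 1.0},\n",
        "    {\"word\": \"safe?\", \"start\": 1.1, \"end\": 1.6},\n",
        "    {\"word\": \"It\", \"start\": 1.9, \"end\": 2.1},\n",
        "    {\"word\": \"really\", \"start\": 2.2, \"end\": 2.5},\n",
        "    {\"word\": \"depends.\", \"start\": 2.6, \"end\": 3.0},\n",
        "]\n",
        "# Speaker turns are [start_ms, end_ms, speaker_id], in milliseconds.\n",
        "toy_turns = [[0, 2000, 0], [2000, 4000, 1]]\n",
        "\n",
        "# Step 1: assign each word to a speaker turn based on the word's start time.\n",
        "toy_wsm = get_words_speaker_mapping(toy_words, toy_turns, \"start\")\n",
        "\n",
        "# Step 2: use sentence-ending punctuation to give every word in a sentence\n",
        "# the speaker who holds the most words in that sentence.\n",
        "toy_realigned = get_realigned_ws_mapping_with_punctuation(toy_wsm)\n",
        "\n",
        "for before, after in zip(toy_wsm, toy_realigned):\n",
        "    stamp = format_timestamp(before[\"start_time\"], always_include_hours=True)\n",
        "    print(\n",
        "        f\"{stamp}  {before['word']:<10} \"\n",
        "        f\"raw: {before['speaker']}  realigned: {after['speaker']}\"\n",
        "    )"
      ]
    },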
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ZS4xXmE2NGP"
      },
      "source": [
        "# 3. Separating music from speech using Demucs\n",
        "\n",
        "This section focuses on a crucial preprocessing step that enhances the clarity of speech before diarization. Demucs, a deep learning model for music source separation, is introduced as a powerful tool for isolating vocals from complex audio tracks.\n",
        "\n",
        "Separating music from speech is essential when dealing with recordings containing background music or non-speech elements. By extracting the vocal component, the diarization system can more effectively analyze and attribute speech to the correct speakers, as the spectral and temporal characteristics of their speech signals become more pronounced and less obscured by music.\n",
        "\n",
        "Demucs employs a neural network trained to distinguish between different audio sources within a mixture. When applied to an audio file, it separates the vocal track from the instrumental, allowing subsequent tools like Whisper and NeMo to process the speech without interference from background music.\n",
        "\n",
        "This separation step not only benefits the accuracy of speaker diarization but also improves the performance of downstream tasks that require clean speech input, such as transcription and speech recognition. By incorporating Demucs into the preprocessing pipeline, the notebook ensures that the input to the diarization system is optimized for the best possible results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HKcgQUrAzsJZ"
      },
      "outputs": [],
      "source": [
        "if enable_stemming:\n",
        "    # Isolate vocals from the rest of the audio\n",
        "\n",
        "    return_code = os.system(\n",
        "        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals \"{audio_path}\" -o \"temp_outputs\"'\n",
        "    )\n",
        "\n",
        "    if return_code != 0:\n",
        "        logging.warning(\"Source splitting failed, using original audio file.\")\n",
        "        vocal_target = audio_path\n",
        "    else:\n",
        "        vocal_target = os.path.join(\n",
        "            \"temp_outputs\",\n",
        "            \"htdemucs\",\n",
        "            os.path.splitext(os.path.basename(audio_path))[0],\n",
        "            \"vocals.wav\",\n",
        "        )\n",
        "else:\n",
        "    vocal_target = audio_path"
      ]
    },
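    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shell command in the cell above can also be assembled as an argument list, which avoids quoting problems with unusual file names. The following is a minimal sketch rather than part of the original pipeline: `build_demucs_command` and `expected_vocals_path` are hypothetical helper names, `sys.executable` stands in for `python3`, and the output layout is assumed to match the path built in the cell above.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import sys\n",
        "\n",
        "\n",
        "def build_demucs_command(audio_path, out_dir='temp_outputs'):\n",
        "    # Hypothetical helper: same flags as the os.system call above, passed as a\n",
        "    # list so that no shell quoting is needed. sys.executable replaces python3.\n",
        "    return [\n",
        "        sys.executable, '-m', 'demucs.separate',\n",
        "        '-n', 'htdemucs',\n",
        "        '--two-stems=vocals',\n",
        "        audio_path,\n",
        "        '-o', out_dir,\n",
        "    ]\n",
        "\n",
        "\n",
        "def expected_vocals_path(audio_path, out_dir='temp_outputs'):\n",
        "    # Assumed layout, mirroring the cell above: <out_dir>/htdemucs/<track>/vocals.wav\n",
        "    track = os.path.splitext(os.path.basename(audio_path))[0]\n",
        "    return os.path.join(out_dir, 'htdemucs', track, 'vocals.wav')\n",
        "\n",
        "\n",
        "cmd = build_demucs_command('interview.mp3')\n",
        "print(cmd[1:])\n",
        "print(expected_vocals_path('interview.mp3'))\n",
        "```\n",
        "\n",
        "Passing the list to `subprocess.run(cmd).returncode` would give the same kind of success check as `return_code` in the original cell."
      ]
    },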
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UYg9VWb22Tz8"
      },
      "source": [
        "# 4. Transcribing audio using WhisperX\n",
        "\n",
        "This section focuses on leveraging WhisperX to accurately transcribe audio content. Whisper's robust capabilities enable it to convert speech to text across a wide range of languages and dialects.\n",
        "\n",
        "The transcription process involves processing the audio file through WhisperX to generate a set of text segments, each accompanied by timestamps indicating when the segment was spoken. This step lays the foundation for speaker diarization and further analysis by providing the necessary textual content..\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5-VKFn530oTl"
      },
      "outputs": [],
      "source": [
        "compute_type = \"float16\"\n",
        "# or run on GPU with INT8\n",
        "# compute_type = \"int8_float16\"\n",
        "# or run on CPU with INT8\n",
        "# compute_type = \"int8\"\n",
        "\n",
        "if batch_size != 0:\n",
        "    whisper_results, language = transcribe_batched(\n",
        "        vocal_target,\n",
        "        language,\n",
        "        batch_size,\n",
        "        whisper_model_name,\n",
        "        compute_type,\n",
        "        suppress_numerals,\n",
        "        device,\n",
        "    )\n",
        "else:\n",
        "    whisper_results, language = transcribe(\n",
        "        vocal_target,\n",
        "        language,\n",
        "        whisper_model_name,\n",
        "        compute_type,\n",
        "        suppress_numerals,\n",
        "        device,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PNjPQcFB2tyE"
      },
      "source": [
        "# 5. Aligning the transcription with the original audio using Wav2Vec2\n",
        "\n",
        "Wav2Vec2, a large-scale neural network, is employed here to learn speech representations that are useful for various speech processing tasks, including speech recognition and alignment.\n",
        "\n",
        "The code loads the Wav2Vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the `vocal_target` file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
        "\n",
        "By combining the outputs of Whisper and Wav2Vec2, the code produces a fully aligned transcription of the speech contained in the `vocal_target` file. This aligned transcription is valuable for downstream tasks such as speaker diarization, sentiment analysis, and language identification.\n",
        "\n",
        "If no Wav2Vec2 model is available for the specified language, word timestamps generated by Whisper will be used instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aAT7EgdD2tyE"
      },
      "outputs": [],
      "source": [
        "if language in wav2vec2_langs:\n",
        "    device = \"cuda\"\n",
        "    alignment_model, metadata = whisperx.load_align_model(\n",
        "        language_code=language, device=device\n",
        "    )\n",
        "    result_aligned = whisperx.align(\n",
        "        whisper_results, alignment_model, metadata, vocal_target, device\n",
        "    )\n",
        "    word_timestamps = filter_missing_timestamps(\n",
        "        result_aligned[\"word_segments\"],\n",
        "        initial_timestamp=whisper_results[0].get(\"start\"),\n",
        "        final_timestamp=whisper_results[-1].get(\"end\"),\n",
        "    )\n",
        "\n",
        "    # clear gpu vram\n",
        "    del alignment_model\n",
        "    torch.cuda.empty_cache()\n",
        "else:\n",
        "    assert batch_size == 0, (  # TODO: add a better check for word timestamps existence\n",
        "        f\"Unsupported language: {language}, use --batch_size to 0\"\n",
        "        \" to generate word timestamps using whisper directly and fix this error.\"\n",
        "    )\n",
        "    word_timestamps = []\n",
        "    for segment in whisper_results:\n",
        "        for word in segment[\"words\"]:\n",
        "            word_timestamps.append({\"word\": word[2], \"start\": word[0], \"end\": word[1]})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D1gkViCf2-CV"
      },
      "source": [
        "# 6. Using NeMo's MSDD model for speaker diarization\n",
        "This code employs the NVIDIA NeMo MSDD (Multi-scale Diarization Decoder) model to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nir92iyb2tyF"
      },
      "outputs": [],
      "source": [
        "# Convert audio to mono for NeMo combatibility\n",
        "sound = AudioSegment.from_file(vocal_target).set_channels(1)\n",
        "ROOT = os.getcwd()\n",
        "temp_path = os.path.join(ROOT, \"temp_outputs\")\n",
        "os.makedirs(temp_path, exist_ok=True)\n",
        "sound.export(os.path.join(temp_path, \"mono_file.wav\"), format=\"wav\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C7jIpBCH02RL"
      },
      "outputs": [],
      "source": [
        "# Initialize NeMo MSDD diarization model\n",
        "msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(\"cuda\")\n",
        "msdd_model.diarize()\n",
        "\n",
        "del msdd_model\n",
        "torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NmkZYaDAEOAg"
      },
      "source": [
        "# 7. Mapping speakers to sentences according to timestamps\n",
        "\n",
        "The code in this section reads the speaker labels and their corresponding timestamps from the output file generated by the NeMo MSDD model. It then uses the `get_words_speaker_mapping` function to map each word in the transcription to its respective speaker based on the timestamp information.\n",
        "\n",
        "This mapping process ensures that each word is attributed to the correct speaker, creating a comprehensive representation of who spoke what and when. The resulting `wsm` (Word-Speaker Mapping) variable contains a list of dictionaries, where each dictionary represents a word and its associated speaker, start time, and end time.\n",
        "\n",
        "By mapping speakers to sentences according to timestamps, the code lays the foundation for further analysis and processing of the diarized transcription."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E65LUGQe02zw"
      },
      "outputs": [],
      "source": [
        "# Reading timestamps <> Speaker Labels mapping\n",
        "\n",
        "speaker_ts = []\n",
        "with open(os.path.join(temp_path, \"pred_rttms\", \"mono_file.rttm\"), \"r\") as f:\n",
        "    lines = f.readlines()\n",
        "    for line in lines:\n",
        "        line_list = line.split(\" \")\n",
        "        s = int(float(line_list[5]) * 1000)\n",
        "        e = s + int(float(line_list[8]) * 1000)\n",
        "        speaker_ts.append([s, e, int(line_list[11].split(\"_\")[-1])])\n",
        "\n",
        "wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, \"start\")"
      ]
    },
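    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The RTTM parsing in the cell above relies on fixed indices after splitting on single spaces. As a minimal sketch of an alternative, the hypothetical `parse_rttm_line` helper below splits on any run of whitespace instead. It assumes the ten-field RTTM layout shown in the comment and speaker labels of the form `speaker_N`; it is an illustration, not the notebook's own helper.\n",
        "\n",
        "```python\n",
        "def parse_rttm_line(line):\n",
        "    # Assumed field order:\n",
        "    # SPEAKER <file> <channel> <start> <duration> <NA> <NA> <speaker> <NA> <NA>\n",
        "    fields = line.split()  # no argument: runs of spaces collapse into one separator\n",
        "    start_ms = int(float(fields[3]) * 1000)\n",
        "    end_ms = start_ms + int(float(fields[4]) * 1000)\n",
        "    speaker_id = int(fields[7].split('_')[-1])  # 'speaker_1' -> 1\n",
        "    return [start_ms, end_ms, speaker_id]\n",
        "\n",
        "\n",
        "sample = 'SPEAKER mono_file 1   0.500   1.250 <NA> <NA> speaker_1 <NA> <NA>'\n",
        "print(parse_rttm_line(sample))  # [500, 1750, 1]\n",
        "```\n",
        "\n",
        "Each returned triple has the same `[start_ms, end_ms, speaker_id]` shape as the entries appended to `speaker_ts`."
      ]
    },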
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Ruxc8S1EXtW"
      },
      "source": [
        "# 8. Enhancing speaker attribution with punctuation-based realignment\n",
        "\n",
        "This section introduces a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. By leveraging punctuation marks, the code determines the dominant speaker for each sentence in the transcription.\n",
        "\n",
        "For example, consider the following scenario:\n",
        "\n",
        "```\n",
        "Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are\n",
        "Speaker B: going to suck, right? So it's actually it hits you on both sides.\n",
        "```\n",
        "\n",
        "If a sentence is split between two speakers, the code identifies the most frequent speaker label for each word in the sentence and assigns that label to the entire sentence. This approach helps improve the accuracy of speaker diarization, particularly when the Whisper model may not account for fine utterances like \"hmm\" and \"yeah,\" but the NeMo Diarization Model does, resulting in inconsistent outputs.\n",
        "\n",
        "Additionally, the code handles situations where one speaker is delivering a monologue while other speakers make occasional comments in the background. In such cases, it ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time.\n",
        "\n",
        "By realigning speech segments based on punctuation, the code provides a robust and reliable method for enhancing speaker attribution in the transcription."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pgfC5hA41BXu"
      },
      "outputs": [],
      "source": [
        "if language in punct_model_langs:\n",
        "    # restoring punctuation in the transcript to help realign the sentences\n",
        "    punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
        "\n",
        "    words_list = list(map(lambda x: x[\"word\"], wsm))\n",
        "\n",
        "    labled_words = punct_model.predict(words_list)\n",
        "\n",
        "    ending_puncts = \".?!\"\n",
        "    model_puncts = \".,;:!?\"\n",
        "\n",
        "    # We don't want to punctuate U.S.A. with a period. Right?\n",
        "    is_acronym = lambda x: re.fullmatch(r\"\\b(?:[a-zA-Z]\\.){2,}\", x)\n",
        "\n",
        "    for word_dict, labeled_tuple in zip(wsm, labled_words):\n",
        "        word = word_dict[\"word\"]\n",
        "        if (\n",
        "            word\n",
        "            and labeled_tuple[1] in ending_puncts\n",
        "            and (word[-1] not in model_puncts or is_acronym(word))\n",
        "        ):\n",
        "            word += labeled_tuple[1]\n",
        "            if word.endswith(\"..\"):\n",
        "                word = word.rstrip(\".\")\n",
        "            word_dict[\"word\"] = word\n",
        "\n",
        "else:\n",
        "    logging.warning(\n",
        "        f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
        "    )\n",
        "\n",
        "wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
        "ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"
      ]
    },
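    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The majority-vote idea described in this section can be illustrated in isolation. The sketch below is a simplified stand-in, not the `get_realigned_ws_mapping_with_punctuation` helper used above, whose actual logic may differ. The function name `assign_majority_speaker` is hypothetical, and the input is assumed to be a list of dictionaries with `word` and `speaker` keys.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "\n",
        "def assign_majority_speaker(words):\n",
        "    # Group words into sentences at '.', '?' or '!' and relabel every word\n",
        "    # of a sentence with that sentence's most frequent speaker.\n",
        "    result, sentence = [], []\n",
        "    for index, item in enumerate(words):\n",
        "        sentence.append(item)\n",
        "        is_last = index == len(words) - 1\n",
        "        if item['word'].endswith(('.', '?', '!')) or is_last:\n",
        "            counts = Counter(w['speaker'] for w in sentence)\n",
        "            majority = counts.most_common(1)[0][0]\n",
        "            result.extend({**w, 'speaker': majority} for w in sentence)\n",
        "            sentence = []\n",
        "    return result\n",
        "\n",
        "\n",
        "words = [\n",
        "    {'word': 'Yeah.', 'speaker': 1},\n",
        "    {'word': 'That', 'speaker': 0},\n",
        "    {'word': 'one', 'speaker': 0},\n",
        "    {'word': 'is', 'speaker': 0},\n",
        "    {'word': 'fun.', 'speaker': 1},  # stray label inside speaker 0's sentence\n",
        "]\n",
        "print([w['speaker'] for w in assign_majority_speaker(words)])  # [1, 0, 0, 0, 0]\n",
        "```\n",
        "\n",
        "The last word is relabeled to speaker 0 because three of the four words in its sentence carry that label."
      ]
    },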
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vF2QAtLOFvwZ"
      },
      "source": [
        "# 9. Finalizing the diarization process\n",
        "\n",
        "In this final section, the code performs essential cleanup tasks, exports the diarization results for further use, and replaces speaker IDs with their corresponding names. The main steps include:\n",
        "\n",
        "1. **Saving the Speaker-Aware Transcript**: The `get_speaker_aware_transcript` function is used to generate a transcript that incorporates both the textual content and the corresponding speaker information. This transcript is then saved to a file with the same name as the input audio file, but with a \".txt\" extension.\n",
        "\n",
        "2. **Exporting the Diarization Results in SRT Format**: The `write_srt` function is employed to export the diarization results in the SubRip Text (SRT) format. This format is commonly used for subtitles and includes speaker labels and precise timestamps for each utterance. The SRT file is saved with the same name as the input audio file, but with a \".srt\" extension.\n",
        "\n",
        "3. **Cleaning Up Temporary Files**: The `cleanup` function is called to remove any temporary files or directories created during the diarization process. This step ensures a clean and organized working environment, freeing up storage space and maintaining system efficiency.\n",
        "\n",
        "4. **Mapping Speaker IDs to Speaker Names**: The code reads the content of the previously saved speaker-aware transcript file and replaces the generic speaker IDs (e.g., \"Speaker 0\", \"Speaker 1\", \"Speaker 2\") with the actual names of the speakers.\n",
        "\n",
        "By completing these final steps, the diarization process is concluded, and the results are made available for further analysis, post-processing, or integration with other tools and workflows. The exported speaker-aware transcript, SRT file, and the transcript with mapped speaker names provide valuable insights into the content and structure of the audio recording, enabling a wide range of applications, such as content analysis, speaker identification, and subtitle generation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kFTyKI6B1MI0"
      },
      "outputs": [],
      "source": [
        "with open(f\"{os.path.splitext(audio_path)[0]}.txt\", \"w\", encoding=\"utf-8-sig\") as f:\n",
        "    get_speaker_aware_transcript(ssm, f)\n",
        "\n",
        "with open(f\"{os.path.splitext(audio_path)[0]}.srt\", \"w\", encoding=\"utf-8-sig\") as srt:\n",
        "    write_srt(ssm, srt)\n",
        "\n",
        "cleanup(temp_path)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Open the file\n",
        "with open(f\"{os.path.splitext(audio_path)[0]}.txt\", 'r') as f:\n",
        "    text = f.read()\n",
        "\n",
        "# Replace the speaker IDs with names\n",
        "text = text.replace('Speaker 0','Ewa Jasiewicz')\n",
        "text = text.replace('Speaker 1','Chris Faulkner')\n",
        "text = text.replace('Speaker 2','Matt Frei')\n",
        "\n",
        "# Write the file to disk\n",
        "with open(audio_path[:-4] + '-with-speakers-names.txt', 'w') as f:\n",
        "    f.write(text)"
      ],
      "metadata": {
        "id": "Iq2aJbBb6EZU"
      },
      "execution_count": null,
      "outputs": []
    }
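    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plain `str.replace` works for the three speakers in this recording, but `Speaker 1` would also match the start of `Speaker 10` in a recording with more speakers. The following is a minimal sketch of a pattern-based alternative. The `rename_speakers` helper is hypothetical, and it assumes labels of the form `Speaker N` as in the cell above.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "\n",
        "def rename_speakers(text, names):\n",
        "    # names maps numeric speaker IDs to display names, e.g. {0: 'Ewa Jasiewicz'}.\n",
        "    def substitute(match):\n",
        "        speaker_id = int(match.group(1))\n",
        "        # Labels without a known name are left untouched.\n",
        "        return names.get(speaker_id, match.group(0))\n",
        "\n",
        "    # [0-9]+ consumes the whole number, so 'Speaker 1' cannot match inside 'Speaker 10'.\n",
        "    return re.sub(r'Speaker ([0-9]+)', substitute, text)\n",
        "\n",
        "\n",
        "transcript = 'Speaker 1: Hello. Speaker 10: Hi.'\n",
        "print(rename_speakers(transcript, {1: 'Chris Faulkner'}))\n",
        "```\n",
        "\n",
        "Only the first label is replaced here, because no name is supplied for speaker 10."
      ]
    }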
  ]
}